{
 "cells": [
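  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Building a Handwritten Digit Recognition Model with a Minimal Implementation\n",
    "\n",
    "The previous section introduced the \"horizontal-vertical\" teaching method, which helps deep learning beginners pick up the theory quickly while gaining hands-on modeling experience along the way. In this method, the vertical track gives an overview of the basic code structure of a model and a minimal implementation, as shown in **Figure 1**. This section uses that minimal implementation to build a handwritten digit recognition model quickly.\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/762c127363684c32832cb61b5d6deaa013023131a36948b6b695cec2df72f791\" width=\"1000\" height=\"\" ></center>\n",
    "<center><br>Figure 1: The vertical track of the \"horizontal-vertical\" teaching method, a minimal implementation</br></center>\n",
    "<br></br>\n",
    "\n",
    "### Prerequisites\n",
    "\n",
    "Before processing any data, first load the PaddlePaddle libraries needed for the handwritten digit recognition model, as shown below."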
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
   ],
   "source": [
    "# Load PaddlePaddle and the related libraries\r\n",
    "import paddle\r\n",
    "from paddle.nn import Linear\r\n",
    "import paddle.nn.functional as F\r\n",
    "import os\r\n",
    "import numpy as np\r\n",
    "import matplotlib.pyplot as plt"
   ]
  },
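  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Data Processing\n",
    "\n",
    "PaddlePaddle provides a number of ready-made dataset APIs covering computer vision, natural language processing, recommender systems, and other fields, so that readers can get started on deep learning tasks quickly. For the handwritten digit recognition task, [paddle.vision.datasets.MNIST](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/datasets/mnist/MNIST_cn.html) gives direct access to the preprocessed MNIST training and test sets. The PaddlePaddle APIs support the following common academic datasets:\n",
    "\n",
    "* mnist\n",
    "* cifar\n",
    "* Conll05\n",
    "* imdb\n",
    "* imikolov\n",
    "* movielens\n",
    "* sentiment\n",
    "* uci_housing\n",
    "* wmt14\n",
    "* wmt16\n",
    "\n",
    "The following code sets up the data reader through the paddle.vision.datasets.MNIST API."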
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Set up the data reader; the API loads the MNIST training set automatically\n",
    "train_dataset = paddle.vision.datasets.MNIST(mode='train')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The following code reads one sample from the dataset and prints its contents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACcCAYAAACUcfL+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAACmFJREFUeJztnW+MFdUZxn8PK6ygINClBlbCGhaJtIlA1lpDo+s/QFMkfGkQA0psShoQm0IQSlpIY1KrTZsgJK21BC2VxqqINDYEDHwgtoTdQHBRV8ECruVvLGKxqUBPP9xhO2fi7t6999yZe/e+v+RmzzNn7sy77MOZd8659x055zCMYumXdQBG38CMZATBjGQEwYxkBMGMZATBjGQEoaqNJOmgpOas4+gLyOaRjBBU9YhkhKOqjSTpiKS7Ja2W9CdJGyV9JultSTdIWiHplKSPJE2NvW++pHejfT+UtCBx3GWSjkv6h6TvSnKSGqO+Wkm/kHRM0klJv5Y0MO3fPTRVbaQEM4DfA8OAfcA2cv8+9cBPgd/E9j0FfBsYAswHfiVpMoCk6cAPgbuBRqA5cZ4ngRuAiVF/PfCTUvxCqeKcq9oXcITcH3w1sD22fQbwL6Am0oMBBwzt4jivAY9F7fXAz2J9jdF7GwEB54Gxsf5bgb9n/W9R7OuKFLxaKZyMtf8NnHHOXYppgKuBs5LuBVaRG1n6AYOAt6N9RgEtsWN9FGuPiPZtlXR5m4CaQL9DZpiReomkWuAVYB6wxTl3QdJr5AwBcBy4LvaW0bH2GXKm/Jpz7uM04k0Ly5F6zwCgFjgNXIxGp6mx/peA+ZJulDQI+PHlDufcf4HfksupvgogqV7StNSiLxFmpF7inPsMWEzOMP8E5gCvx/r/AqwBdgKHgL9FXf+Jfj5+ebukc8AOYHwqwZcQm5AsMZJuBNqAWufcxazjKRU2IpUASbOi+aJhwM+BrX3ZRGBGKhULyM01HQYuAd/PNpzSY5c2IwhFjUiSpktql3RI0vJQQRmVR8EjkqQa4H3gHqAD2As84Jx7J1x4RqVQzITkN4BDzrkPAST9EZgJdGmkuro619DQUMQpjbRpbW0945wb0dN+xRipHn/6vwO4pbs3NDQ00NLS0t0uRpkh6Wg++5X8rk3S9yS1SGo5ffp0qU9nZEQxRvoYfx3pumibh3PuWedck3OuacSIHkdIo0Ipxkh7gXGSrpc0AJhNbKnAqC4KzpGccxclLSL3AbAaYL1z7mCwyIyKoqiPkTjn3gDeCBSLUcHYEokRBDOSEQQzkhEEM5IRBDOSEQQzkhEEM5IRBDOSEQQzkhEEM5IRBDOSEQQzkhEE++5/nly6dMnTn376ad7vXbt2rac///xzT7e3t3t63bp1nl66dKmnN23a5Okrr7zS08uX//97GKtWrco7zmKwEckIghnJCIIZyQhC1eRIx44d8/QXX3zh6bfeesvTu3fv9vTZs2c9/fLLLweLbfTo0Z5+9NFHPb1582ZPDx482NM33XSTp2+//fZgseWLjUhGEMxIRhDMSEYQ+myOtG/fPk/feeednu7NPFBoamr82qNPPPGEp6+66ipPP/jgg54eNWqUp4cNG+bp8ePTLwBnI5IRBDOSEQQzkhGEPpsjjRkzxtN1dXWeDpkj3XKLX4QlmbPs3LnT0wMGDPD03Llzg8WSFTYiGUEwIxlBMCMZQeizOdLw4cM9/fTTT3t669atnp40aZKnFy9e3O3xJ06c2NnesWOH15ecB2pra/P0mjVruj12JWIjkhGEHo0kaX30FMW22LbhkrZL+iD6Oay7Yxh9n3xGpA3A9MS25cCbzrlxwJuRNqqYvOpsS2oA/uyc+3qk24Fm59xxSSOBXc65Hhd4mpqaXLlUtT137pynk5/xWbDAe0wtzz33nKc3btzY2Z4zZ07g6MoHSa3Ouaae9is0R7rWOXc8ap8Ari3wOEYf
oehk2+WGtC6HNSuPXB0UaqST0SWN6Oeprna08sjVQaHzSK8DD5F79PhDwJZgEaXEkCFDuu2/5ppruu2P50yzZ8/2+vr1q75ZlXxu/zcBfwXGS+qQ9Ag5A90j6QNyjzt/srRhGuVOjyOSc+6BLrruChyLUcFU3xhslIQ+u9ZWLKtXr/Z0a2urp3ft2tXZTq61TZ06lWrDRiQjCGYkIwhmJCMIqT5lu5zW2nrL4cOHPT158uTO9tChQ72+O+64w9NNTf5S1cKFCz0tKUSIJaHUa22G4WFGMoJgt/95MnbsWE9v2LChsz1//nyv74UXXuhWnz9/3tPz5s3z9MiRIwsNMzNsRDKCYEYygmBGMoJgOVKBzJo1q7Pd2Njo9S1ZssTTySWUFStWePro0aOeXrlypafr6+sLjjMtbEQygmBGMoJgRjKCYEskJSBZSjn59fCHH37Y08m/wV13+Z8Z3L59e7jgeoktkRipYkYygmBGMoJgOVIG1NbWevrChQue7t+/v6e3bdvm6ebm5pLE9WVYjmSkihnJCIIZyQiCrbUF4MCBA55OPoJr7969nk7mREkmTJjg6dtuu62I6NLBRiQjCGYkIwhmJCMIliPlSfKR6s8880xn+9VXX/X6Tpw40atjX3GF/2dIfma7EsrklH+ERkWQT32k0ZJ2SnpH0kFJj0XbrUSy0Uk+I9JFYIlzbgLwTWChpAlYiWQjRj6Fto4Dx6P2Z5LeBeqBmUBztNvzwC7g8ZJEmQLJvObFF1/09Nq1az195MiRgs918803ezr5Ge3777+/4GNnRa9ypKje9iRgD1Yi2YiRt5EkXQ28AvzAOedVO++uRLKVR64O8jKSpP7kTPQH59zle928SiRbeeTqoMccSbmaK78D3nXO/TLWVVElkk+ePOnpgwcPenrRokWefu+99wo+V/LRpMuWLfP0zJkzPV0J80Q9kc+E5BRgLvC2pP3Rth+RM9BLUbnko8B3ShOiUQnkc9e2G+iqEpSVSDYAm9k2AtFn1to++eQTTycfk7V//35PJ0v59ZYpU6Z0tpPf9Z82bZqnBw4cWNS5KgEbkYwgmJGMIJiRjCBUVI60Z8+ezvZTTz3l9SU/F93R0VHUuQYNGuTp5OPb4+tjycezVyM2IhlBMCMZQaioS9vmzZu/tJ0Pya/4zJgxw9M1NTWeXrp0qaeT1f0NHxuRjCCYkYwgmJGMIFhZG6NbrKyNkSpmJCMIZiQjCGYkIwhmJCMIZiQjCGYkIwhmJCMIZiQjCGYkIwhmJCMIqa61STpN7lu5dcCZ1E7cO8o1tqziGuOc67FoQ6pG6jyp1JLPQmAWlGts5RrXZezSZgTBjGQEISsjPZvRefOhXGMr17iAjHIko+9hlzYjCKkaSdJ0Se2SDknKtJyypPWSTklqi20ri9rhlVjbPDUjSaoB1gH3AhOAB6J63VmxAZie2FYutcMrr7a5cy6VF3ArsC2mVwAr0jp/FzE1AG0x3Q6MjNojgfYs44vFtQW4p1zjc86lemmrBz6K6Y5oWzlRdrXDK6W2uSXbXeBy/+0zvaUttLZ5FqRppI+B0TF9XbStnMirdngaFFPbPAvSNNJeYJyk6yUNAGaTq9VdTlyuHQ4Z1g7Po7Y5lFtt85STxvuA94HDwMqME9hN5B7Wc4FcvvYI8BVyd0MfADuA4RnF9i1yl60DwP7odV+5xPdlL5vZNoJgybYRBDOSEQQzkhEEM5IRBDOSEQQzkhEEM5IRBDOSEYT/AefqSFIluHjbAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图像数据形状和对应数据为: (28, 28)\n",
      "图像标签形状和对应数据为: (1,) [5]\n",
      "\n",
      "打印第一个batch的第一个图像，对应标签数字为[5]\n"
     ]
    }
   ],
   "source": [
    "train_data0 = np.array(train_dataset[0][0])\r\n",
    "train_label_0 = np.array(train_dataset[0][1])\r\n",
    "\r\n",
    "# Display the first image in the training set\r\n",
    "import matplotlib.pyplot as plt\r\n",
    "plt.figure(\"Image\") # figure window name\r\n",
    "plt.figure(figsize=(2,2))\r\n",
    "plt.imshow(train_data0, cmap=plt.cm.binary)\r\n",
    "plt.axis('on') # use 'off' to hide the axes\r\n",
    "plt.title('image') # figure title\r\n",
    "plt.show()\r\n",
    "\r\n",
    "print(\"图像数据形状和对应数据为:\", train_data0.shape)\r\n",
    "print(\"图像标签形状和对应数据为:\", train_label_0.shape, train_label_0)\r\n",
    "print(\"\\n打印第一个batch的第一个图像，对应标签数字为{}\".format(train_label_0))"
   ]
  },
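  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The image is displayed with matplotlib, as shown in **Figure 2**. The digit in the picture is 5, which matches its label.\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/a07d9b3b5839434e98afe05a298d3ce1c9b6cbc02124488a9bd8b7c2efeb42c4\" width=\"300\" height=\"\" ></center>\n",
    "<center><br>Figure 2: Output displayed with matplotlib</br></center>\n",
    "<br></br>\n",
    "\n",
    "------\n",
    "**Note:**\n",
    "\n",
    "Each sample returned by paddle.vision.datasets.MNIST here is a 28×28 image, as the printed shape above shows. The fully connected model built later in this section takes a vector as input, so each image is flattened into a vector of length 784 (28×28) before it is fed to the model.\n",
    "\n",
    "------"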
   ]
  },
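  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The relationship between a 28×28 image and its 784-element vector form can be illustrated with a small NumPy sketch. The array `fake_img` below is a made-up stand-in for a digit image, not data from the MNIST dataset, and the sketch assumes row-major flattening, which is what `reshape` does by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for one 28x28 grayscale image (values 0..783, not real MNIST pixels)\n",
    "fake_img = np.arange(28 * 28).reshape(28, 28)\n",
    "\n",
    "# Flatten row by row into a vector of length 784\n",
    "flat = fake_img.reshape(-1)\n",
    "print(flat.shape)\n",
    "\n",
    "# The pixel at (row, col) lands at index row * 28 + col in the vector\n",
    "row, col = 3, 5\n",
    "assert flat[row * 28 + col] == fake_img[row, col]\n",
    "\n",
    "# Reshaping back to 28x28 recovers the original layout exactly\n",
    "assert np.array_equal(flat.reshape(28, 28), fake_img)"
   ]
  },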
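  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## How to Use the PaddlePaddle APIs\n",
    "\n",
    "Being fluent with the PaddlePaddle APIs is the foundation for completing deep learning tasks with PaddlePaddle, and a skill every developer needs.\n",
    "\n",
    "**Where to find the API documentation and how it is organized**\n",
    "\n",
    "The API documentation is available at \"[PaddlePaddle website -> Docs -> API docs](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/index_cn.html)\". Recent PaddlePaddle releases include many improvements to the APIs; the directory structure and descriptions are shown in **Figure 3**.\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/316984568d8e4e189fe3449108fa1d76a7d82330834f41139f2aaba8f745d49a\" width=\"900\" height=\"\" ></center>\n",
    "<center><br>Figure 3: PaddlePaddle API documentation directory</br></center>\n",
    "<br></br>\n",
    "\n",
    "**How to use the API documentation**\n",
    "\n",
    "Every PaddlePaddle API page follows the same structure, with four parts: the interface signature, a description of the functionality and the formula it computes, the parameters and return value, and a code example. Taking the ReLU function as an example, the page structure is shown in **Figure 4**. With the API documentation, readers can study what a function does in detail and also practice using the API through runnable code examples.\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/badc3b56be924955b97cc30d253eb4c850582d8e94004d5c876cf17eac8aee15\" width=\"700\" height=\"\" ></center>\n",
    "<center><br>Figure 4: API documentation for ReLU</br></center>\n",
    "<br></br>"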
   ]
  },
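  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Model Design\n",
    "\n",
    "In the house price prediction task, we used a single-layer model with no nonlinear transformation and obtained good predictions. For handwritten digit recognition we reuse the same model to predict the digit shown in the input image. The model takes 784-dimensional (28×28) data as input and produces a 1-dimensional output, as shown in **Figure 5**.\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/9c146e7d9c4a4119a8cd09f7c8b5ee61f2ac1820a221429a80430291728b9c4a\" width=\"400\" height=\"\" ></center>\n",
    "<center><br>Figure 5: Network model for handwritten digit recognition</br></center>\n",
    "<br></br>\n",
    "\n",
    "The positional layout of the input pixels matters a great deal for understanding image content (for example, if the pixels of an original 28×28 image were rearranged into a 7×112 layout, the digit would no longer be recognizable). For this reason the network input is designed as 28×28 rather than 1×784, so that the model can handle the spatial relationships between pixels correctly.\n",
    "\n",
    "------\n",
    "**Note:**\n",
    "\n",
    "In fact, a simple network with only one layer (a weighted sum of the inputs) does not make use of this positional information, so we can expect its predictive performance to be limited. The convolutional neural network introduced in the later optimization sections takes positional relationships into account much better, and the model's predictions improve significantly as a result.\n",
    "\n",
    "------\n",
    "\n",
    "The network for handwritten digit recognition is built as a class, as shown below."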
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Define the network for MNIST digit recognition; same structure as the house price prediction network\n",
    "class MNIST(paddle.nn.Layer):\n",
    "    def __init__(self):\n",
    "        super(MNIST, self).__init__()\n",
    "        \n",
    "        # Define one fully connected layer with an output dimension of 1\n",
    "        self.fc = paddle.nn.Linear(in_features=784, out_features=1)\n",
    "        \n",
    "    # Define the forward computation of the network\n",
    "    def forward(self, inputs):\n",
    "        outputs = self.fc(inputs)\n",
    "        return outputs"
   ]
  },
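  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "As a rough sketch of what a single fully connected layer computes, the NumPy code below evaluates a weighted sum `x @ W + b` for one flattened image. The names `x`, `W`, `b` and the random values are illustrative assumptions, not the parameters that `paddle.nn.Linear` learns. The sketch also shuffles the pixels together with their weights to show that the result does not change, which is one way to see that this layer takes no account of where a pixel sits in the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Hypothetical inputs: one flattened image and randomly initialized parameters\n",
    "x = rng.rand(1, 784)            # one sample with 784 pixel values in [0, 1)\n",
    "W = rng.randn(784, 1) * 0.01    # one weight per pixel\n",
    "b = np.zeros(1)                 # a single bias term\n",
    "\n",
    "# Weighted sum of all pixels plus the bias: one number per image\n",
    "y = x @ W + b\n",
    "print(y.shape)  # (1, 1)\n",
    "\n",
    "# Shuffle the pixels and their weights in the same way: the sum is unchanged\n",
    "perm = rng.permutation(784)\n",
    "y_shuffled = x[:, perm] @ W[perm] + b\n",
    "assert np.allclose(y, y_shuffled)"
   ]
  },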
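  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Training Configuration\n",
    "\n",
    "Training configuration first creates a model instance (set to training mode) and then chooses the optimization algorithm and learning rate (stochastic gradient descent, SGD, with a learning rate of 0.001), as shown below."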
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Instantiate the network\r\n",
    "model = MNIST()\r\n",
    "\r\n",
    "def train(model):\r\n",
    "    # Switch to training mode\r\n",
    "    model.train()\r\n",
    "    # Load the training set with batch_size set to 16\r\n",
    "    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), \r\n",
    "                                        batch_size=16, \r\n",
    "                                        shuffle=True)\r\n",
    "    # Define the optimizer: stochastic gradient descent (SGD) with a learning rate of 0.001\r\n",
    "    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())"
   ]
  },
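  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Training Process\n",
    "\n",
    "Training uses two nested loops. Once training finishes, the model parameters are saved for later use.\n",
    "\n",
    "- Inner loop: performs one pass over the entire dataset, traversing it in batches.\n",
    "- Outer loop: sets how many times the dataset is traversed. In this run the outer loop executes 10 times, as set by the parameter EPOCH_NUM."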
   ]
  },
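  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The two nested loops can be sketched without any framework. The sketch assumes the 60,000-image MNIST training set and the batch size of 16 used in this section; `num_samples` and the counting logic are only illustrative and stand in for the real data loader and training step. It shows how many batches one epoch contains and at which batch ids a message would be logged when logging every 1000 batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_samples = 60000   # size of the MNIST training set\n",
    "batch_size = 16\n",
    "EPOCH_NUM = 10\n",
    "\n",
    "# Number of batches needed to traverse the dataset once\n",
    "batches_per_epoch = math.ceil(num_samples / batch_size)\n",
    "print(batches_per_epoch)  # 3750\n",
    "\n",
    "logged = []\n",
    "for epoch in range(EPOCH_NUM):                  # outer loop: passes over the dataset\n",
    "    for batch_id in range(batches_per_epoch):   # inner loop: one pass, batch by batch\n",
    "        # A real training step (forward, loss, backward, update) would go here\n",
    "        if batch_id % 1000 == 0:\n",
    "            logged.append((epoch, batch_id))\n",
    "\n",
    "print(len(logged))   # 40\n",
    "print(logged[:4])    # [(0, 0), (0, 1000), (0, 2000), (0, 3000)]"
   ]
  },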
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Image normalization: scale image data from the range [0, 255] to [0, 1]\n",
    "def norm_img(img):\n",
    "    # Check the input format; img has shape [batch_size, 28, 28]\n",
    "    assert len(img.shape) == 3\n",
    "    batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]\n",
    "    # Normalize the image data\n",
    "    img = img / 255\n",
    "    # Reshape the images to [batch_size, 784]\n",
    "    img = paddle.reshape(img, [batch_size, img_h*img_w])\n",
    "    \n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch_id: 0, batch_id: 0, loss is: [35.525185]\n",
      "epoch_id: 0, batch_id: 1000, loss is: [7.4399786]\n",
      "epoch_id: 0, batch_id: 2000, loss is: [2.0210705]\n",
      "epoch_id: 0, batch_id: 3000, loss is: [2.325027]\n",
      "epoch_id: 1, batch_id: 0, loss is: [2.4414306]\n",
      "epoch_id: 1, batch_id: 1000, loss is: [4.6318164]\n",
      "epoch_id: 1, batch_id: 2000, loss is: [4.6807127]\n",
      "epoch_id: 1, batch_id: 3000, loss is: [5.7014084]\n",
      "epoch_id: 2, batch_id: 0, loss is: [3.4229655]\n",
      "epoch_id: 2, batch_id: 1000, loss is: [2.1136832]\n",
      "epoch_id: 2, batch_id: 2000, loss is: [2.3517294]\n",
      "epoch_id: 2, batch_id: 3000, loss is: [6.7515297]\n",
      "epoch_id: 3, batch_id: 0, loss is: [4.119179]\n",
      "epoch_id: 3, batch_id: 1000, loss is: [4.4800296]\n",
      "epoch_id: 3, batch_id: 2000, loss is: [3.4902763]\n",
      "epoch_id: 3, batch_id: 3000, loss is: [3.631486]\n",
      "epoch_id: 4, batch_id: 0, loss is: [6.123066]\n",
      "epoch_id: 4, batch_id: 1000, loss is: [2.8558893]\n",
      "epoch_id: 4, batch_id: 2000, loss is: [2.6112337]\n",
      "epoch_id: 4, batch_id: 3000, loss is: [2.0097098]\n",
      "epoch_id: 5, batch_id: 0, loss is: [3.9023933]\n",
      "epoch_id: 5, batch_id: 1000, loss is: [2.1165676]\n",
      "epoch_id: 5, batch_id: 2000, loss is: [3.2067215]\n",
      "epoch_id: 5, batch_id: 3000, loss is: [2.4574804]\n",
      "epoch_id: 6, batch_id: 0, loss is: [1.8463242]\n",
      "epoch_id: 6, batch_id: 1000, loss is: [3.4741895]\n",
      "epoch_id: 6, batch_id: 2000, loss is: [2.057652]\n",
      "epoch_id: 6, batch_id: 3000, loss is: [2.0860665]\n",
      "epoch_id: 7, batch_id: 0, loss is: [3.90655]\n",
      "epoch_id: 7, batch_id: 1000, loss is: [2.5527935]\n",
      "epoch_id: 7, batch_id: 2000, loss is: [3.239427]\n",
      "epoch_id: 7, batch_id: 3000, loss is: [6.7344103]\n",
      "epoch_id: 8, batch_id: 0, loss is: [1.6209174]\n",
      "epoch_id: 8, batch_id: 1000, loss is: [2.686802]\n",
      "epoch_id: 8, batch_id: 2000, loss is: [7.759363]\n",
      "epoch_id: 8, batch_id: 3000, loss is: [3.1380877]\n",
      "epoch_id: 9, batch_id: 0, loss is: [3.1067057]\n",
      "epoch_id: 9, batch_id: 1000, loss is: [2.864774]\n",
      "epoch_id: 9, batch_id: 2000, loss is: [2.528369]\n",
      "epoch_id: 9, batch_id: 3000, loss is: [4.1854725]\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "# Make sure the images loaded from paddle.vision.datasets.MNIST are of type np.ndarray\n",
    "paddle.vision.set_image_backend('cv2')\n",
    "\n",
    "# Instantiate the network\n",
    "model = MNIST()\n",
    "\n",
    "def train(model):\n",
    "    # Switch to training mode\n",
    "    model.train()\n",
    "    # Load the training set with batch_size set to 16\n",
    "    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), \n",
    "                                        batch_size=16, \n",
    "                                        shuffle=True)\n",
    "    # Define the optimizer: stochastic gradient descent (SGD) with a learning rate of 0.001\n",
    "    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())\n",
    "    EPOCH_NUM = 10\n",
    "    for epoch in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            images = norm_img(data[0]).astype('float32')\n",
    "            labels = data[1].astype('float32')\n",
    "            \n",
    "            # Forward pass\n",
    "            predicts = model(images)\n",
    "            \n",
    "            # Compute the loss\n",
    "            loss = F.square_error_cost(predicts, labels)\n",
    "            avg_loss = paddle.mean(loss)\n",
    "            \n",
    "            # Print the current loss every 1000 batches\n",
    "            if batch_id % 1000 == 0:\n",
    "                print(\"epoch_id: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n",
    "            \n",
    "            # Backward pass and parameter update\n",
    "            avg_loss.backward()\n",
    "            opt.step()\n",
    "            opt.clear_grad()\n",
    "            \n",
    "train(model)\n",
    "paddle.save(model.state_dict(), './mnist.pdparams')"
   ]
  },
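  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The loss values printed during training show that, although the loss decreases overall, it is still fairly high in the final epoch. This suggests that reusing the house price prediction code unchanged for handwritten digit recognition does not train well. Next, we test the model to find out how well it actually performs.\n",
    "\n",
    "# Model Testing\n",
    "\n",
    "The main purpose of model testing is to check whether the trained model recognizes digits correctly. It involves the following four steps:\n",
    "\n",
    "* Declare a model instance.\n",
    "* Load the model: load the model parameters saved during training.\n",
    "* Feed in the data: pass the test sample to the model with the model set to evaluation mode (eval), which explicitly tells the framework that only the forward computation will be used from here on, with no gradient computation or backpropagation.\n",
    "* Get the prediction and convert it to an integer, which is output as the predicted label.\n",
    "\n",
    "Before testing the model, read the sample image from the file './work/example_0.jpg' and normalize it."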
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEUtJREFUeJzt3X2MVGWWBvDngIzKl9p2L0EH6NmBaJQgaAnGUYKZHSJIghPUgGbsVdzGZISdZEz8WD8TjEQFnEScyCihWWaZWTN0+NDsDktMyJh1YklclXF3dbEJINBNWhmJJm3D2T/6Ylrse96y3rp1b3ueX9Lp6jr11j1U9UNV13vvfUVVQUT+DMm7ASLKB8NP5BTDT+QUw0/kFMNP5BTDT+QUw0/kFMNP5BTDT+TUGfXcWGNjozY3N9dzk0SudHR04OjRo1LJbaPCLyLXA/gVgKEAXlTVFdbtm5ubUS6XYzaZmZjdnEXsx/rkyZNmfciQuDdg1v2H7jv07w7920Ks+w/dd5a9xd531o9btUqlUsW3rfq3TkSGAlgDYA6ASwAsEpFLqr0/IqqvmJec6QA+VNW9qtoD4HcA5temLSLKWkz4LwSwv9/PB5LrvkZEWkWkLCLlrq6uiM0RUS1l/mm/qq5V1ZKqlpqamrLeHBFVKCb8BwGM6/fz95PriGgQiAn/mwAmicgPROR7ABYC2Fqbtogoa1VP9alqr4jcA+Df0TfVt05V99SsszrLcmomdrottp7VWCD8uOU15QUAvb29qbUzzrB/9Ys6lVdLUfP8qvoqgFdr1AsR1RF37yVyiuEncorhJ3KK4SdyiuEncorhJ3Kqrsfzf1fFzpWHxsce8muJnacPHa5sjc9629Zcfuxh1t+F/QD4yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUp/rqIHbKKkbWZw6OGR/bW8x0moepvBC+8hM5xfATOcXwEznF8BM5xfATOcXwEznF8BM5Vfd5/phVW6u930rEbDt2znfo0KFm/fjx42b98ccfT62tXLnSHDts2DCz3tPTY9avvfZas/7cc8+l1qZMmWKOpWzxlZ/IKYafyCmGn8gphp/IKYafyCmGn8gphp/Iqah5fhHpAPAZgBMAelW1VIumquzFrMfuBxAjdNz63r17zfrNN99s1vfsSV8Z/dJLLzXHtra2mvXXX3/drL/88stmfcmSJam1Z555xhx71VVXmfXQ/hFffvllai20f8N34Xj9kFrs5HOdqh6twf0QUR3xbT+RU7HhVwB/FJG3RMR+/0hEhRL7tv8aVT0oIn8DYIeI/Leq7up/g+Q/hVYAGD9+fOTmiKhWol75VfVg8r0TQDuA6QPcZq2qllS11NTUFLM5IqqhqsMvIiNEZNSpywBmA3ivVo0RUbZi3vaPAdCeTImcAeBfVPXfatIVEWWu6vCr6l4Al9Wwlyh5nmf9xIkTZj00V37nnXea9Y6ODrN+++23p9ZWrVpljj333HPN+tKlS836FVdcYdYffvjh1NqcOXPMsa+99lrUtq0lukN43n4i+s5i+ImcYviJnGL4iZxi+ImcYviJnKr7qbvzPD13DGs6r7e31xz76KOPmvV9+/aZ9RkzZpj1NWvWpNbOPvtsc2zslNann35q1q1Tf4fue8uWLWZ92rRpZj3P07EPBnzlJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3JqUC3RneXca2i+2zpN9BNPPGGO3bVrl1m/6KKLzPq2bdvMujWXn/Whqffdd59Ztw5HfuWVV8yxTz31lFmfOHGiWb/ppptSa8OHDzfHhvbdiDlcuCj4yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUw0/k1KA6nt8SOn12aDnnmPnw9vb2qG0vXrzYrDc0NJj1GKHlw0NGjx5t1jdu3Jha27Rpkzn2rrvuiqpbvxN33HGHOTY0j28t/13JeEu9ziXAV34i
pxh+IqcYfiKnGH4ipxh+IqcYfiKnGH4ip4KTkSKyDsA8AJ2qOjm5rgHA7wE0A+gAcIuqfpJdm2FZzuMDwEcffZRaO3TokDk2tBT13XffbdZDrDnnYcOGmWOznlO27v/WW281x06ZMsWsX3755WZ92bJlqbUxY8aYY+fOnWvWQ4/rYFDJK/96ANefdt39AHaq6iQAO5OfiWgQCYZfVXcB6D7t6vkA2pLLbQBurHFfRJSxav/mH6Oqp97rHgZgv4ciosKJ/sBP+/6YTv2DWkRaRaQsIuWurq7YzRFRjVQb/iMiMhYAku+daTdU1bWqWlLVUlNTU5WbI6Jaqzb8WwG0JJdbANjLqRJR4QTDLyKbAPwngItE5ICILAawAsBPROQDAH+X/ExEg0hwnl9VF6WUflzjXjI9x3zsfPYbb7yRWuvuPn0y5OtmzJhh1keMGFFVT6fEzDmHjucfMsR+fYh5zkLnxp88ebJZb2trM+tLlixJrS1fvtwce+WVV5r1xsZGs16vY/JjcA8/IqcYfiKnGH4ipxh+IqcYfiKnGH4ipwq1znBoeiRmee/QlNQXX3xh1p999tmq73vBggVmPSRmOi00NjSVFxK6f0vsMteLFqXNQvexTqm+efNmc2zo1N7bt2836zGnRI99TireTl22QkSFw/ATOcXwEznF8BM5xfATOcXwEznF8BM5Vah5/izns0P7AXz88cdmfffu3am10Jzu+PHjzXrMXHlI1oeWZjknHXu48fr161Nrn3xin2k+9PsQuyS8Jcvfh/74yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUw0/kVKHm+UNijucPCY235m1bWlpSawBw1llnRW07JMvHJXY+2xofmqeP3Ydg+PDhqbVZs2aZYx955BGzvmWLvU7N/Pnzzbr1uIWes9BzUim+8hM5xfATOcXwEznF8BM5xfATOcXwEznF8BM5FZznF5F1AOYB6FTVycl1jwH4BwBdyc0eVNVXK9lgVufez3qp6Z6entRaaIns0L8rtvcsj9mPnWuPOa49SxMmTIgav3TpUrM+c+ZMs97Q0JBay3r/h6/up4LbrAdw/QDXr1bVqclXRcEnouIIhl9VdwHorkMvRFRHMe8f7hGRd0RknYicV7OOiKguqg3/rwH8EMBUAIcArEy7oYi0ikhZRMpdXV1pNyOiOqsq/Kp6RFVPqOpJAL8BMN247VpVLalqqampqdo+iajGqgq/iIzt9+NPAbxXm3aIqF4qmerbBGAWgEYROQDgUQCzRGQqAAXQAWBJhj0SUQaC4VfVgRZBf6naDcace9+S9TrzMfcfe0x8b2+vWbfWuc97HwJr+7H7XsT0dtttt5n10PH6mzdvNuvHjh0z642NjWa9HriHH5FTDD+RUww/kVMMP5FTDD+RUww/kVOFOnV3zNRN7JTWmWeeadat6bQXX3zRHLt69WqzHurN2nZI7BRo6HEN1WN6z/KU5qHpV+u030C4t9Bh3vVahtvCV34ipxh+IqcYfiKnGH4ipxh+IqcYfiKnGH4ip+o+zx8zv2nNrcYeHjp+/HizPnv27NTajh07zLH79+8365MmTTLrIVk9pkC2p5HO8pDd0Ph9+/aZYzdu3GjWS6WSWT/nnHPMepanW68UX/mJnGL4iZxi+ImcYviJnGL4iZxi+ImcYviJnBpUx/NnOZ8duu+LL744tbZt2zZzbHt7u1m/9957zXqo95g549jTisfM1ccuXR5y8ODB1NoDDzxgjg3tv3DdddeZ9dGjR5v1IuArP5FTDD+RUww/kVMMP5FTDD+RUww/kVMMP5FTwXl+ERkHYAOAMQAUwFpV/ZWINAD4PYBmAB0AblHVT7JrNdtjoEPz1fPmzUuthc7L//TTT5v1Cy64wKwvXLjQrOd5Xv+Y5+Tzzz8368ePHzfrofUSNmzYkFo7cuSIOXbFihVmfdmyZWY9Zh+G2OekUpVspRfAL1X1EgBXAfi5iFwC4H4AO1V1EoCdyc9E
NEgEw6+qh1R1d3L5MwDvA7gQwHwAbcnN2gDcmFWTRFR73+r9hYg0A5gG4M8AxqjqoaR0GH1/FhDRIFFx+EVkJIA/APiFqv61f037/mAe8I9mEWkVkbKIlLu6uqKaJaLaqSj8IjIMfcH/rapuTq4+IiJjk/pYAJ0DjVXVtapaUtVSU1NTLXomohoIhl/6PrZ8CcD7qrqqX2krgJbkcguALbVvj4iyUskc0Y8A/AzAuyLydnLdgwBWAPhXEVkMYB+AWyrZoDUFEnN4aGhsqB6aXpk5c2Zq7aGHHjLHLl++3KyHDuk9fPiwWbdOK37++eebY7u7u8166LDa559/3qxbhwyHDnU+duyYWQ89Z+PGjUutbd++3Rx79dVXm/VYRViiOxh+Vf0TgLTU/bi27RBRvXAPPyKnGH4ipxh+IqcYfiKnGH4ipxh+IqeknvONpVJJy+Vy1eN7e3tTa6FTTIcOsYw9hbXlySefNOtr1qwx652dA+48+RXrObTmugGgo6PDrMf8u0MmTJhg1qdPn27WFy9eXPX4kSNHmmOzXkLbes5itl0qlVAulyu6A77yEznF8BM5xfATOcXwEznF8BM5xfATOcXwEzlVqCW6Q/scxJyiOnRcesx8dqjv0HLQN9xwg1lfv369WX/hhRdSaxMnTjTHLliwwKz39PSY9VmzZpn1yy67LLU2atQoc2zsmZ+s5zx2Hj/m3BOx249duvwUvvITOcXwEznF8BM5xfATOcXwEznF8BM5xfATOTWojucnIhuP5yeiIIafyCmGn8gphp/IKYafyCmGn8gphp/IqWD4RWSciLwmIn8RkT0i8o/J9Y+JyEEReTv5mpt9u0RUK5WcHaMXwC9VdbeIjALwlojsSGqrVfWZ7NojoqwEw6+qhwAcSi5/JiLvA7gw68aIKFvf6m9+EWkGMA3An5Or7hGRd0RknYiclzKmVUTKIlLu6uqKapaIaqfi8IvISAB/APALVf0rgF8D+CGAqeh7Z7ByoHGqulZVS6paij0nGxHVTkXhF5Fh6Av+b1V1MwCo6hFVPaGqJwH8BoC9qiIRFUoln/YLgJcAvK+qq/pdP7bfzX4K4L3at0dEWank0/4fAfgZgHdF5O3kugcBLBKRqQAUQAeAJZl0SESZqOTT/j8BGOj44Fdr3w4R1Qv38CNyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncqquS3SLSBeAff2uagRwtG4NfDtF7a2ofQHsrVq17G2CqlZ0vry6hv8bGxcpq2optwYMRe2tqH0B7K1aefXGt/1ETjH8RE7lHf61OW/fUtTeitoXwN6qlUtvuf7NT0T5yfuVn4hykkv4ReR6EfkfEflQRO7Po4c0ItIhIu8mKw+Xc+5lnYh0ish7/a5rEJEdIvJB8n3AZdJy6q0QKzcbK0vn+tgVbcXrur/tF5GhAP4XwE8AHADwJoBFqvqXujaSQkQ6AJRUNfc5YRGZCeA4gA2qOjm57ikA3aq6IvmP8zxVva8gvT0G4HjeKzcnC8qM7b+yNIAbAfw9cnzsjL5uQQ6PWx6v/NMBfKiqe1W1B8DvAMzPoY/CU9VdALpPu3o+gLbkchv6fnnqLqW3QlDVQ6q6O7n8GYBTK0vn+tgZfeUij/BfCGB/v58PoFhLfiuAP4rIWyLSmnczAxiTLJsOAIcBjMmzmQEEV26up9NWli7MY1fNite1xg/8vukaVb0cwBwAP0/e3haS9v3NVqTpmopWbq6XAVaW/kqej121K17XWh7hPwhgXL+fv59cVwiqejD53gmgHcVbffjIqUVSk++dOffzlSKt3DzQytIowGNXpBWv8wj/mwAmicgPROR7ABYC2JpDH98gIiOSD2IgIiMAzEbxVh/eCqAludwCYEuOvXxNUVZuTltZGjk/doVb8VpV6/4FYC76PvH/PwD/lEcPKX39LYD/Sr725N0bgE3oexv4Jfo+G1kM
4HwAOwF8AOA/ADQUqLd/BvAugHfQF7SxOfV2Dfre0r8D4O3ka27ej53RVy6PG/fwI3KKH/gROcXwEznF8BM5xfATOcXwEznF8BM5xfATOcXwEzn1/6eTz7E34aJzAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始图像shape:  (28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEUtJREFUeJzt3X2MVGWWBvDngIzKl9p2L0EH6NmBaJQgaAnGUYKZHSJIghPUgGbsVdzGZISdZEz8WD8TjEQFnEScyCihWWaZWTN0+NDsDktMyJh1YklclXF3dbEJINBNWhmJJm3D2T/6Ylrse96y3rp1b3ueX9Lp6jr11j1U9UNV13vvfUVVQUT+DMm7ASLKB8NP5BTDT+QUw0/kFMNP5BTDT+QUw0/kFMNP5BTDT+TUGfXcWGNjozY3N9dzk0SudHR04OjRo1LJbaPCLyLXA/gVgKEAXlTVFdbtm5ubUS6XYzaZmZjdnEXsx/rkyZNmfciQuDdg1v2H7jv07w7920Ks+w/dd5a9xd531o9btUqlUsW3rfq3TkSGAlgDYA6ASwAsEpFLqr0/IqqvmJec6QA+VNW9qtoD4HcA5temLSLKWkz4LwSwv9/PB5LrvkZEWkWkLCLlrq6uiM0RUS1l/mm/qq5V1ZKqlpqamrLeHBFVKCb8BwGM6/fz95PriGgQiAn/mwAmicgPROR7ABYC2Fqbtogoa1VP9alqr4jcA+Df0TfVt05V99SsszrLcmomdrottp7VWCD8uOU15QUAvb29qbUzzrB/9Ys6lVdLUfP8qvoqgFdr1AsR1RF37yVyiuEncorhJ3KK4SdyiuEncorhJ3Kqrsfzf1fFzpWHxsce8muJnacPHa5sjc9629Zcfuxh1t+F/QD4yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUp/rqIHbKKkbWZw6OGR/bW8x0moepvBC+8hM5xfATOcXwEznF8BM5xfATOcXwEznF8BM5Vfd5/phVW6u930rEbDt2znfo0KFm/fjx42b98ccfT62tXLnSHDts2DCz3tPTY9avvfZas/7cc8+l1qZMmWKOpWzxlZ/IKYafyCmGn8gphp/IKYafyCmGn8gphp/Iqah5fhHpAPAZgBMAelW1VIumquzFrMfuBxAjdNz63r17zfrNN99s1vfsSV8Z/dJLLzXHtra2mvXXX3/drL/88stmfcmSJam1Z555xhx71VVXmfXQ/hFffvllai20f8N34Xj9kFrs5HOdqh6twf0QUR3xbT+RU7HhVwB/FJG3RMR+/0hEhRL7tv8aVT0oIn8DYIeI/Leq7up/g+Q/hVYAGD9+fOTmiKhWol75VfVg8r0TQDuA6QPcZq2qllS11NTUFLM5IqqhqsMvIiNEZNSpywBmA3ivVo0RUbZi3vaPAdCeTImcAeBfVPXfatIVEWWu6vCr6l4Al9Wwlyh5nmf9xIkTZj00V37nnXea9Y6ODrN+++23p9ZWrVpljj333HPN+tKlS836FVdcYdYffvjh1NqcOXPMsa+99lrUtq0lukN43n4i+s5i+ImcYviJnGL4iZxi+ImcYviJnKr7qbvzPD13DGs6r7e31xz76KOPmvV9+/aZ9RkzZpj1NWvWpNbOPvtsc2zslNann35q1q1Tf4fue8uWLWZ92rRpZj3P07EPBnzlJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3JqUC3RneXca2i+2zpN9BNPPGGO3bVrl1m/6KKLzPq2bdvMujWXn/Whqffdd59Ztw5HfuWVV8yxTz31lFmfOHGiWb/ppptSa8OHDzfHhvbdiDlcuCj4yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUw0/k1KA6nt8SOn12aDnnmPnw9vb2qG0vXrzYrDc0NJj1GKHlw0NGjx5t1jdu3Jha27Rpkzn2rrvuiqpbvxN33HGHOTY0j28t/13JeEu9ziXAV34i
pxh+IqcYfiKnGH4ipxh+IqcYfiKnGH4ip4KTkSKyDsA8AJ2qOjm5rgHA7wE0A+gAcIuqfpJdm2FZzuMDwEcffZRaO3TokDk2tBT13XffbdZDrDnnYcOGmWOznlO27v/WW281x06ZMsWsX3755WZ92bJlqbUxY8aYY+fOnWvWQ4/rYFDJK/96ANefdt39AHaq6iQAO5OfiWgQCYZfVXcB6D7t6vkA2pLLbQBurHFfRJSxav/mH6Oqp97rHgZgv4ciosKJ/sBP+/6YTv2DWkRaRaQsIuWurq7YzRFRjVQb/iMiMhYAku+daTdU1bWqWlLVUlNTU5WbI6Jaqzb8WwG0JJdbANjLqRJR4QTDLyKbAPwngItE5ICILAawAsBPROQDAH+X/ExEg0hwnl9VF6WUflzjXjI9x3zsfPYbb7yRWuvuPn0y5OtmzJhh1keMGFFVT6fEzDmHjucfMsR+fYh5zkLnxp88ebJZb2trM+tLlixJrS1fvtwce+WVV5r1xsZGs16vY/JjcA8/IqcYfiKnGH4ipxh+IqcYfiKnGH4ipwq1znBoeiRmee/QlNQXX3xh1p999tmq73vBggVmPSRmOi00NjSVFxK6f0vsMteLFqXNQvexTqm+efNmc2zo1N7bt2836zGnRI99TireTl22QkSFw/ATOcXwEznF8BM5xfATOcXwEznF8BM5Vah5/izns0P7AXz88cdmfffu3am10Jzu+PHjzXrMXHlI1oeWZjknHXu48fr161Nrn3xin2k+9PsQuyS8Jcvfh/74yk/kFMNP5BTDT+QUw0/kFMNP5BTDT+QUw0/kVKHm+UNijucPCY235m1bWlpSawBw1llnRW07JMvHJXY+2xofmqeP3Ydg+PDhqbVZs2aZYx955BGzvmWLvU7N/Pnzzbr1uIWes9BzUim+8hM5xfATOcXwEznF8BM5xfATOcXwEznF8BM5FZznF5F1AOYB6FTVycl1jwH4BwBdyc0eVNVXK9lgVufez3qp6Z6entRaaIns0L8rtvcsj9mPnWuPOa49SxMmTIgav3TpUrM+c+ZMs97Q0JBay3r/h6/up4LbrAdw/QDXr1bVqclXRcEnouIIhl9VdwHorkMvRFRHMe8f7hGRd0RknYicV7OOiKguqg3/rwH8EMBUAIcArEy7oYi0ikhZRMpdXV1pNyOiOqsq/Kp6RFVPqOpJAL8BMN247VpVLalqqampqdo+iajGqgq/iIzt9+NPAbxXm3aIqF4qmerbBGAWgEYROQDgUQCzRGQqAAXQAWBJhj0SUQaC4VfVgRZBf6naDcace9+S9TrzMfcfe0x8b2+vWbfWuc97HwJr+7H7XsT0dtttt5n10PH6mzdvNuvHjh0z642NjWa9HriHH5FTDD+RUww/kVMMP5FTDD+RUww/kVOFOnV3zNRN7JTWmWeeadat6bQXX3zRHLt69WqzHurN2nZI7BRo6HEN1WN6z/KU5qHpV+u030C4t9Bh3vVahtvCV34ipxh+IqcYfiKnGH4ipxh+IqcYfiKnGH4ip+o+zx8zv2nNrcYeHjp+/HizPnv27NTajh07zLH79+8365MmTTLrIVk9pkC2p5HO8pDd0Ph9+/aZYzdu3GjWS6WSWT/nnHPMepanW68UX/mJnGL4iZxi+ImcYviJnGL4iZxi+ImcYviJnBpUx/NnOZ8duu+LL744tbZt2zZzbHt7u1m/9957zXqo95g549jTisfM1ccuXR5y8ODB1NoDDzxgjg3tv3DdddeZ9dGjR5v1IuArP5FTDD+RUww/kVMMP5FTDD+RUww/kVMMP5FTwXl+ERkHYAOAMQAUwFpV/ZWINAD4PYBmAB0AblHVT7JrNdtjoEPz1fPmzUuthc7L//TTT5v1Cy64wKwvXLjQrOd5Xv+Y5+Tzzz8368ePHzfrofUSNmzYkFo7cuSIOXbFihVmfdmyZWY9Zh+G2OekUpVspRfAL1X1EgBXAfi5iFwC4H4AO1V1EoCdyc9E
NEgEw6+qh1R1d3L5MwDvA7gQwHwAbcnN2gDcmFWTRFR73+r9hYg0A5gG4M8AxqjqoaR0GH1/FhDRIFFx+EVkJIA/APiFqv61f037/mAe8I9mEWkVkbKIlLu6uqKaJaLaqSj8IjIMfcH/rapuTq4+IiJjk/pYAJ0DjVXVtapaUtVSU1NTLXomohoIhl/6PrZ8CcD7qrqqX2krgJbkcguALbVvj4iyUskc0Y8A/AzAuyLydnLdgwBWAPhXEVkMYB+AWyrZoDUFEnN4aGhsqB6aXpk5c2Zq7aGHHjLHLl++3KyHDuk9fPiwWbdOK37++eebY7u7u8166LDa559/3qxbhwyHDnU+duyYWQ89Z+PGjUutbd++3Rx79dVXm/VYRViiOxh+Vf0TgLTU/bi27RBRvXAPPyKnGH4ipxh+IqcYfiKnGH4ipxh+IqeknvONpVJJy+Vy1eN7e3tTa6FTTIcOsYw9hbXlySefNOtr1qwx652dA+48+RXrObTmugGgo6PDrMf8u0MmTJhg1qdPn27WFy9eXPX4kSNHmmOzXkLbes5itl0qlVAulyu6A77yEznF8BM5xfATOcXwEznF8BM5xfATOcXwEzlVqCW6Q/scxJyiOnRcesx8dqjv0HLQN9xwg1lfv369WX/hhRdSaxMnTjTHLliwwKz39PSY9VmzZpn1yy67LLU2atQoc2zsmZ+s5zx2Hj/m3BOx249duvwUvvITOcXwEznF8BM5xfATOcXwEznF8BM5xfATOTWojucnIhuP5yeiIIafyCmGn8gphp/IKYafyCmGn8gphp/IqWD4RWSciLwmIn8RkT0i8o/J9Y+JyEEReTv5mpt9u0RUK5WcHaMXwC9VdbeIjALwlojsSGqrVfWZ7NojoqwEw6+qhwAcSi5/JiLvA7gw68aIKFvf6m9+EWkGMA3An5Or7hGRd0RknYiclzKmVUTKIlLu6uqKapaIaqfi8IvISAB/APALVf0rgF8D+CGAqeh7Z7ByoHGqulZVS6paij0nGxHVTkXhF5Fh6Av+b1V1MwCo6hFVPaGqJwH8BoC9qiIRFUoln/YLgJcAvK+qq/pdP7bfzX4K4L3at0dEWank0/4fAfgZgHdF5O3kugcBLBKRqQAUQAeAJZl0SESZqOTT/j8BGOj44Fdr3w4R1Qv38CNyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEncqquS3SLSBeAff2uagRwtG4NfDtF7a2ofQHsrVq17G2CqlZ0vry6hv8bGxcpq2optwYMRe2tqH0B7K1aefXGt/1ETjH8RE7lHf61OW/fUtTeitoXwN6qlUtvuf7NT0T5yfuVn4hykkv4ReR6EfkfEflQRO7Po4c0ItIhIu8mKw+Xc+5lnYh0ish7/a5rEJEdIvJB8n3AZdJy6q0QKzcbK0vn+tgVbcXrur/tF5GhAP4XwE8AHADwJoBFqvqXujaSQkQ6AJRUNfc5YRGZCeA4gA2qOjm57ikA3aq6IvmP8zxVva8gvT0G4HjeKzcnC8qM7b+yNIAbAfw9cnzsjL5uQQ6PWx6v/NMBfKiqe1W1B8DvAMzPoY/CU9VdALpPu3o+gLbkchv6fnnqLqW3QlDVQ6q6O7n8GYBTK0vn+tgZfeUij/BfCGB/v58PoFhLfiuAP4rIWyLSmnczAxiTLJsOAIcBjMmzmQEEV26up9NWli7MY1fNite1xg/8vukaVb0cwBwAP0/e3haS9v3NVqTpmopWbq6XAVaW/kqej121K17XWh7hPwhgXL+fv59cVwiqejD53gmgHcVbffjIqUVSk++dOffzlSKt3DzQytIowGNXpBWv8wj/mwAmicgPROR7ABYC2JpDH98gIiOSD2IgIiMAzEbxVh/eCqAludwCYEuOvXxNUVZuTltZGjk/doVb8VpV6/4FYC76PvH/PwD/lEcPKX39LYD/Sr725N0bgE3oexv4Jfo+G1kM
4HwAOwF8AOA/ADQUqLd/BvAugHfQF7SxOfV2Dfre0r8D4O3ka27ej53RVy6PG/fwI3KKH/gROcXwEznF8BM5xfATOcXwEznF8BM5xfATOcXwEzn1/6eTz7E34aJzAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "采样后图片shape:  (28, 28)\n"
     ]
    }
   ],
   "source": [
    "# Import the third-party libraries for reading images\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "img_path = './work/example_0.jpg'\n",
    "# Read the original image and display it\n",
    "im = Image.open(img_path)\n",
    "plt.imshow(im)\n",
    "plt.show()\n",
    "# Convert the original image to grayscale\n",
    "im = im.convert('L')\n",
    "print('原始图像shape: ', np.array(im).shape)\n",
    "# Resample the image with the LANCZOS filter (called ANTIALIAS in older Pillow releases)\n",
    "im = im.resize((28, 28), Image.LANCZOS)\n",
    "plt.imshow(im)\n",
    "plt.show()\n",
    "print(\"采样后图片shape: \", np.array(im).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "result Tensor(shape=[1, 1], dtype=float32, place=CPUPlace, stop_gradient=False,\n",
      "       [[1.11405170]])\n",
      "本次预测的数字是 [[1]]\n"
     ]
    }
   ],
   "source": [
    "# Read a local sample image and convert it to the model's input format\r\n",
    "def load_image(img_path):\r\n",
    "    # Read the image from img_path and convert it to grayscale\r\n",
    "    im = Image.open(img_path).convert('L')\r\n",
    "    # print(np.array(im))\r\n",
    "    im = im.resize((28, 28), Image.LANCZOS)\r\n",
    "    im = np.array(im).reshape(1, -1).astype(np.float32)\r\n",
    "    # Normalize to [0, 1] and invert the pixel values so the range matches the dataset\r\n",
    "    im = 1 - im / 255\r\n",
    "    return im\r\n",
    "\r\n",
    "# Define the prediction procedure\r\n",
    "model = MNIST()\r\n",
    "params_file_path = 'mnist.pdparams'\r\n",
    "img_path = './work/example_0.jpg'\r\n",
    "# Load the model parameters\r\n",
    "param_dict = paddle.load(params_file_path)\r\n",
    "model.load_dict(param_dict)\r\n",
    "# Feed in the data\r\n",
    "model.eval()\r\n",
    "tensor_img = load_image(img_path)\r\n",
    "result = model(paddle.to_tensor(tensor_img))\r\n",
    "print('result',result)\r\n",
    "# Convert the prediction to an integer to get the predicted digit, then print it\r\n",
    "print(\"本次预测的数字是\", result.numpy().astype('int32'))"
   ]
  },
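  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The printed result shows that the digit predicted by the model does not match the digit in the input image. Only one sample was checked here; trying more samples shows that many digit images are recognized incorrectly. Reusing the house price prediction experiment unchanged is therefore not suitable for the handwritten digit recognition task.\n",
    "\n",
    "Next, we will improve the handwritten digit recognition model step by step until the results are satisfactory."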
   ]
  },
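  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Exercise 2-1\n",
    "\n",
    "1. Use the mode parameter of the PaddlePaddle API [paddle.vision.datasets.MNIST](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/datasets/mnist/MNIST_cn.html) to obtain the test set, and compute the accuracy of the current model.\n",
    "\n",
    "2. How could the accuracy of the model be improved further? Before moving on to the next part, write down the optimization ideas you can think of."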
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
